Conversation
|
I don't know how julia> gradient(d -> d["one"]*2 + d["one"]*3, Dict("one"=>1, "two"=>2))
(a, b) = (Dict{Any, Any}("one" => 5.0), Dict{Any, Any}("one" => 5.0))
(Dict{Any, Any}("one" => 5.0),)where I've added |
|
This is the issue from the Slack thread: JuliaDiff/ChainRules.jl#662. Note that testing that issue with this PR requires the Molly master branch. |
That's correct. Accumulation for Dicts (and mutable structs) is pretty weird in Zygote because of the need to support Line 36 in 5c80f55 |
|
We could probably add that special case of |
|
Ideally CR wouldn't have to make any changes here, will have to look into the PR + original issue in more depth though.
It should be doing this. If there's any rule where it's not then I think that should be called a bug. |
|
The z2d conversion looks good. All that's left other than tests is the other end in |
Here it doesn't seem to Lines 36 to 37 in 5c80f55 |
Oh true, I forgot that. |
|
Ah, you are right and I was mixing up |
|
After testing locally with the line in #1288 (comment) returning Lines 585 to 586 in 4183226 Were there more tests taking differentiating wrt a Note that this is just for Dicts. Doing the same for |
|
For reference, this is the test that fails when one removes the accum overload linked up top: Again, the only reason there's a single failure is because there is only one test for nested diff with implicit params. Is there perhaps a way for us to constrain this accum call so that it's only valid in that circumstance? |
Can someone link me to the issues about this?
This path isn't well explored in ChainRules, but it is defined.
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/fbb4936204cb1d857c2dd41eac4bd7bf497771b2/src/tangent_types/tangent.jl#L56-L58
This is the code Zygote uses for accumulating,
Zygote.jl/src/lib/base.jl
Lines 26 to 29 in c335d6d
which actually looks like it is different to what ChainRulesCore will do?
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/fbb4936204cb1d857c2dd41eac4bd7bf497771b2/src/tangent_types/tangent.jl#L334
TODO: